
Fix BART attention fusion for SDPA pattern from transformers >= 4.49 #27458

Merged

tianleiwu merged 2 commits into microsoft:main from Rishi-Dave:rishidave/fix/bart-sdpa-attention-fusion on Feb 27, 2026
Conversation

@Rishi-Dave
Contributor

Summary

  • Add SDPA-aware pattern matching to FusionBartAttention so that BART attention fusion succeeds on models exported with HuggingFace Transformers >= 4.49
  • Add a synthetic BART SDPA graph generator and unit test

Motivation

Fixes #23864

HuggingFace Transformers >= 4.49 replaced BartAttention with BartSdpaAttention (commit 2c47618), changing the ONNX export graph topology in several ways that broke FusionBartAttention pattern matching. Running optimize_model(..., model_type="bart") on these newer exports produces zero fused Attention nodes.

Changes

fusion_bart_attention.py

The SDPA refactor introduces four structural changes to the attention subgraph. Each required a new match path (a condensed sketch of all four follows the list):

  1. QKV output path — LayerNormalization anchor fallback
    For SDPA models, symbolic shape inference often fails, which prevents SkipLayerNormalization fusion. When the anchor node is a plain LayerNormalization instead of SkipLayerNormalization, there's an extra residual Add between the LayerNorm and the attention output projection. Added a fallback match: ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"] with [0, None, 0, 0, 0, 0].

  2. QK path — NaN guard (Where + IsNaN)
    SDPA wraps the Softmax output in a NaN guard: Where(IsNaN(softmax), 0.0, softmax). The Where node's input[2] is the Softmax output. Added two new QK paths:

    • No mask: ["Where", "Softmax", "MatMul"] with [0, 2, 0]
    • With mask: ["Where", "Softmax", "Add", "MatMul"] with [0, 2, 0, 0]
  3. Q and K scaling paths
    Instead of a single combined scale on the QK MatMul output, SDPA applies separate Mul(1/sqrt(head_dim)) to Q and K before the QK MatMul. Added:

    • Q path: ["Mul", "Transpose", "Reshape", "Add", "MatMul"] with [0, 0, 0, 0, None]
    • K path: ["Mul", "Reshape", "Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"] with [1, 0, 0, 0, 0, 0, 0, None] (K^T uses a Reshape→Transpose(0,2,1)→Reshape chain)
  4. num_heads fallback for dynamic shapes
    SDPA models use -1 in reshape shape tensors for dynamic dimensions, causing get_num_heads_and_hidden_size to return negative values. Added a fallback to user-specified num_heads/hidden_size when detected values are invalid.
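As promised above, here is a condensed sketch of how these four additions might look inside FusionBartAttention.fuse. Variable names such as normalize_node, matmul_qkv, and matmul_qk follow the conventions of the existing fusion code; this is an illustration distilled from the PR description, not the verbatim implementation:

```python
# Sketch only: condensed from the PR description, not the exact code.

# 1. QKV output path: fall back to a plain LayerNormalization anchor with
#    an extra residual Add when SkipLayerNormalization fusion did not run.
qkv_nodes = self.model.match_parent_path(
    normalize_node,
    ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
    [0, None, 0, 0, 0, 0],
)

# 2. QK path: Softmax is wrapped in Where(IsNaN(softmax), 0.0, softmax),
#    so the Softmax is reached through input[2] of the Where node.
qk_nodes = self.model.match_parent_path(
    matmul_qkv, ["Where", "Softmax", "MatMul"], [0, 2, 0]
) or self.model.match_parent_path(
    matmul_qkv, ["Where", "Softmax", "Add", "MatMul"], [0, 2, 0, 0]
)

# 3. Separate Mul(1/sqrt(head_dim)) scaling of Q and K before the QK MatMul,
#    instead of one combined scale afterwards.
q_nodes = self.model.match_parent_path(
    matmul_qk,
    ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
    [0, 0, 0, 0, None],
)
k_nodes = self.model.match_parent_path(
    matmul_qk,
    ["Mul", "Reshape", "Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"],
    [1, 0, 0, 0, 0, 0, 0, None],
)

# 4. Reshape shape tensors use -1 for dynamic dims, so shape detection can
#    return negative values; fall back to the user-specified configuration.
num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
if num_heads <= 0 or hidden_size <= 0:
    num_heads, hidden_size = self.num_heads, self.hidden_size
```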

bart_model_generator.py (new)

A synthetic BART SDPA attention graph generator that builds a minimal but complete attention subgraph matching the SDPA topology. It covers both the with_mask=True (decoder self-attention) and with_mask=False (encoder attention) variants; a fragment of the node construction is sketched below.
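To give a flavor of the construction, here is a minimal fragment covering only the NaN guard from point 2 above; the node and tensor names are hypothetical, not the generator's actual ones:

```python
from onnx import TensorProto, helper

# Hypothetical fragment: the NaN guard that SDPA exports place around the
# Softmax output, i.e. Where(IsNaN(softmax), 0.0, softmax).
zero = helper.make_tensor("zero", TensorProto.FLOAT, [], [0.0])  # add as graph initializer
softmax = helper.make_node("Softmax", ["qk"], ["probs"], axis=-1)
isnan = helper.make_node("IsNaN", ["probs"], ["probs_is_nan"])
# Where takes input[1] where the condition is true and input[2] where it is
# false, so input[2] carries the original Softmax output.
where = helper.make_node("Where", ["probs_is_nan", "zero", "probs"], ["probs_safe"])
```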

test_attention_fusion.py

Added test_bart_attention_sdpa_fusion that verifies the following (a condensed sketch of the test appears after this list):

  • 1 Attention node is produced for each mask variant
  • Correct num_heads attribute
  • Correct unidirectional attribute (1 for decoder self-attention with mask, 0 for encoder)
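A condensed sketch of the test shape; the generator entry point create_bart_sdpa_attention and the num_heads/hidden_size values are illustrative assumptions, not the actual test code:

```python
from onnx.helper import get_attribute_value
from onnxruntime.transformers.optimizer import optimize_model

def _attr(node, name):
    # Return the value of the named attribute on an ONNX node.
    return next(get_attribute_value(a) for a in node.attribute if a.name == name)

def test_bart_attention_sdpa_fusion(self):
    for with_mask in (True, False):
        # create_bart_sdpa_attention is a hypothetical stand-in for the
        # generator's entry point in bart_model_generator.py.
        model_path = create_bart_sdpa_attention(with_mask=with_mask)
        optimized = optimize_model(model_path, model_type="bart", num_heads=4, hidden_size=16)
        attention = [n for n in optimized.model.graph.node if n.op_type == "Attention"]
        self.assertEqual(len(attention), 1)
        self.assertEqual(_attr(attention[0], "num_heads"), 4)
        self.assertEqual(_attr(attention[0], "unidirectional"), 1 if with_mask else 0)
```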

Test Plan

  • python -m pytest test_attention_fusion.py -v — all 10 tests pass
  • lintrunner on all 3 changed files — no issues
  • Verified on a real exported BART SDPA model (hf-internal-testing/tiny-random-bart): 2 Attention nodes fused, graph reduced from 120 → 34 nodes (reproducible as sketched below)
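The real-model check can be reproduced along these lines, assuming the HuggingFace model has already been exported to a local ONNX file (the path below is a placeholder):

```python
import onnx
from onnxruntime.transformers.optimizer import optimize_model

model_path = "tiny-random-bart.onnx"  # placeholder: an export of hf-internal-testing/tiny-random-bart
before = len(onnx.load(model_path).graph.node)
optimized = optimize_model(model_path, model_type="bart")
fused = sum(1 for n in optimized.model.graph.node if n.op_type == "Attention")
print(f"nodes: {before} -> {len(optimized.model.graph.node)}, Attention nodes fused: {fused}")
```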

HuggingFace Transformers >= 4.49 replaced BartAttention with
BartSdpaAttention, changing the ONNX graph topology in several ways
that broke FusionBartAttention pattern matching. This adds SDPA-aware
match paths so that BART attention fusion succeeds on modern exports.
@Rishi-Dave force-pushed the rishidave/fix/bart-sdpa-attention-fusion branch from 982b168 to fe4dfce on February 25, 2026 17:59
@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

tianleiwu previously approved these changes Feb 27, 2026
…lback

- Document why mask presence is derived from the QK pattern match result
  rather than re-walking the graph (line 352 feedback).
- Add logger.debug when num_heads/hidden_size falls back to user-specified
  values, logging both detected and fallback values (line 410 feedback).
@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@tianleiwu tianleiwu enabled auto-merge (squash) February 27, 2026 05:30
@tianleiwu tianleiwu merged commit 028f88c into microsoft:main Feb 27, 2026
86 of 90 checks passed
